# mypy: ignore-errors

import functools
import inspect
import logging
import operator
import textwrap
import types
import unittest
from typing import Dict, List

import sympy

import torch._numpy as tnp
import torch.fx
import torch.random
from torch._dynamo import compiled_autograd
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx.experimental.symbolic_shapes import (
    guard_scalar,
    GuardOnDataDependentSymNode,
    has_free_symbols,
    is_symbolic,
    SymTypes,
)
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from .. import config, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented, UserError, UserErrorType
from ..external_utils import call_hook_from_backward_state
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource
from ..utils import (
    fqn,
    get_custom_getattr,
    get_fake_value,
    get_real_value,
    guard_if_dyn,
    object_has_getattribute,
    product,
    proxy_args_kwargs,
    set_example_value,
    tensortype_to_dtype,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .lists import SizeVariable

try:
    import numpy as np
except ModuleNotFoundError:
    np = None

log = logging.getLogger(__name__)

# Ops that allow tensor <op> tensor
supported_tensor_comparison_ops = {
    ">": operator.gt,
    "<": operator.lt,
    ">=": operator.ge,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
}
# Ops that allow tensor <op> None
supported_const_comparison_ops = {
    "is": operator.is_,
    "is not": operator.is_not,
    "==": operator.eq,
    "!=": operator.ne,
}
supported_comparison_ops = {
    **supported_tensor_comparison_ops,
    **supported_const_comparison_ops,
}
supported_tensor_comparison_op_values = dict.fromkeys(
    supported_tensor_comparison_ops.values()
)
supported_const_comparison_op_values = dict.fromkeys(
    supported_const_comparison_ops.values()
)


class TensorVariable(VariableTracker):
    """A torch.Tensor input or an intermediate value in the FX graph"""

    _nonvar_fields = {
        "proxy",
        "dtype",
        "device",
        "layout",
        "ndim",
        "size",
        "stride",
        "requires_grad",
        "is_quantized",
        "is_contiguous",
        "is_sparse",
        "class_type",
        "specialized_value",
        "_is_name_set",
        *VariableTracker._nonvar_fields,
    }

    def get_real_value(self):
        """
        Get the actual value represented by this variable if computation is run
        using the user-provided inputs.
        NOTE: this runs actual tensor computation and may be
        slow and memory-intensive.
        """
        return get_real_value(self.proxy.node, self.proxy.tracer)

    def __init__(
        self,
        proxy: torch.fx.Proxy,
        *,
        dtype,
        device,
        layout,
        ndim,
        requires_grad,
        is_quantized,
        is_sparse,
        class_type,
        has_grad_fn,
        size=None,
        stride=None,
        is_contiguous=None,
        _is_name_set=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.proxy = proxy
        self.dtype = dtype
        self.device = device
        self.layout = layout
        self.ndim = ndim
        self.size = size
        self.stride = stride
        self.requires_grad = requires_grad
        self.is_quantized = is_quantized
        self.is_contiguous = is_contiguous
        self.is_sparse = is_sparse
        self.class_type = class_type
        self.has_grad_fn = has_grad_fn
        if _is_name_set is None:
            # no need to rename inputs
            _is_name_set = self.proxy.node.op == "placeholder"
        self._is_name_set: bool = _is_name_set

    def debug_repr(self):
        # TODO: strip off fake tensor from repr here
        return repr(self.proxy.node.meta["example_value"])

    def as_proxy(self):
        return self.proxy

    def python_type(self):
        return self.class_type

    @staticmethod
    def specialize(value: torch.Tensor):
        props = {
            "dtype": value.dtype,
            "device": value.device,
            "layout": value.layout,
            "ndim": int(value.ndim),
            "requires_grad": value.requires_grad,
            "is_quantized": value.is_quantized,
            "is_sparse": value.is_sparse,
            "class_type": type(value),
        }
        try:
            props["has_grad_fn"] = value.grad_fn is not None
        except Exception:
            # Workaround for issues with create_parameter_op in Dynamo. Reading
            # grad_fn should never cause an issue.
            props["has_grad_fn"] = False

        if is_sparse_any(value) and not has_free_symbols(value):
            props["size"] = tuple(
                [int(s) if is_symbolic(s) else s for s in value.size()]
            )
        elif not has_free_symbols(value):
            # this is a fully static shape, and the keys on props here inform specialization.
            # We have to cast to int here, because these might get accessed as ConstantVariable, which has
            # a strict no-symint policy. If we got here due to not having free symbols, this is a known constant
            # already. We could remove the discrepancy here, by having ConstantVariable be more permissive for
            # constant backed SymInts, but that assert being strict has led to some good signal in hunting bugs, and
            # I'd like to keep it around for now.
            props["size"] = tuple(
                # the non is_symbolic case applies to the jagged layout
                # NestedTensor case as singleton ints are not symbolic
                [int(s) if is_symbolic(s) else s for s in value.size()]
            )
            props["stride"] = tuple(value.stride())
            if torch._C._functorch.is_batchedtensor(value):
                # Batched tensors does not support contiguity patterns, so
                # we refrain from computing the `is_contiguous` property
                props["is_contiguous"] = None
            else:
                props["is_contiguous"] = tuple(
                    [
                        x
                        for x in torch._prims_common._memory_formats
                        if value.is_contiguous(memory_format=x)
                    ]
                )
        return props

    def dynamic_getattr(self, tx, name):
        fake_val = self.proxy.node.meta["example_value"]
        # For getattrs on tensors without sources,
        # we can do better than the default (creating a GetAttrVariable)
        # if:
        # (1) the tensor is a traceable tensor subclass
        # (2) We are getattr'ing an inner tensor from that subclass
        if not self.source and is_traceable_wrapper_subclass(fake_val):
            fake_val = self.proxy.node.meta["example_value"]
            attrs, ctx = fake_val.__tensor_flatten__()
            proxy = getattr(self.as_proxy(), name)
            example_value = getattr(fake_val, name)
            if name in attrs:
                # attrs returned from tensor_flatten are always tensors
                assert isinstance(example_value, torch.Tensor)
                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(tx=tx, proxy=proxy, example_value=example_value)
            # any other attributes on the subclass (that are not methods)
            # are assumed to be constant metadata.
            elif not callable(example_value):
                from .builder import SourcelessBuilder

                return SourcelessBuilder.create(tx, example_value)

        if not (self.source and self.source.subguards_allowed()):
            raise NotImplementedError

        # For local source, we associate the real value. We use this real value
        # for implementing getattr fallthrough on the variable tracker base class.

        # Note - this scope construction is mirrored in guards
        # A subsequent PR will introduce a util.
        scope = {"L": tx.output.local_scope, "G": tx.output.global_scope}
        try:
            # We raise in case we get a typerror bug w/ SuperSource.
            # SuperSource has bugs in it atm, and can produce code like
            # eval("super(L['mod'].model.model.encoder.embed_positions.forward__class__,
            # L['mod'].model.model.encoder.embed_positions)", scope)
            # Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope.
            _input_associated_real_value = eval(self.source.name(), scope)
        except Exception as exc:
            raise NotImplementedError from exc

        if _input_associated_real_value is None:
            raise NotImplementedError

        if object_has_getattribute(_input_associated_real_value):
            raise NotImplementedError

        if get_custom_getattr(_input_associated_real_value):
            raise NotImplementedError

        real_value = getattr(_input_associated_real_value, name)
        if callable(real_value):
            # Callables have more nuanced handling, and we should let the existing system delegate here.
            # Raising was past behavior and so should always be sound to fall back.
            # Note - at a certain point we may want to handle
            raise NotImplementedError

        from ..guards import GuardBuilder
        from .builder import VariableBuilder

        attr_source = AttrSource(self.source, name)
        install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
        return VariableBuilder(tx, attr_source)(real_value)

    def method_attr_ndim(self, tx):
        if self.ndim is not None:
            return ConstantVariable.create(self.ndim)
        else:
            return self.call_method(tx, "dim", [], {})

    def method_attr_dtype(self, tx):
        if self.dtype is not None:
            return ConstantVariable.create(self.dtype)

    def method_attr_device(self, tx):
        if self.device is not None:
            return ConstantVariable.create(self.device)

    def method_attr_layout(self, tx):
        if self.layout is not None:
            return ConstantVariable.create(self.layout)

    def method_attr_is_cuda(self, tx):
        if self.device is not None:
            return ConstantVariable.create(self.device.type == "cuda")

    def method_attr_shape(self, tx):
        if self.size is not None:
            sizes = [variables.ConstantVariable.create(x) for x in self.size]
            return SizeVariable(sizes)
        else:
            return self.call_method(tx, "size", [], {})

    def method_attr_requires_grad(self, tx):
        if self.requires_grad is not None:
            return ConstantVariable.create(self.requires_grad)

    def method_attr_is_quantized(self, tx):
        if self.is_quantized is not None:
            return ConstantVariable.create(self.is_quantized)

    def method_attr_is_sparse(self, tx):
        if self.is_sparse is not None:
            return ConstantVariable.create(self.is_sparse)

    def method_attr_data(self, tx):
        return self.call_method(tx, "detach", [], {})

    def method_attr_grad_fn(self, tx):
        if self.has_grad_fn:
            unimplemented("TensorVariable has a grad_fn")
        else:
            return variables.ConstantVariable(None)

    def method_attr__version(self, tx):
        from ..tensor_version_op import _tensor_version

        return variables.TorchInGraphFunctionVariable(_tensor_version).call_function(
            tx, [self], {}
        )

    def var_getattr(self, tx, name):
        from . import UserDefinedClassVariable

        if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
            unimplemented(f"Illegal getattr invocation {name} in strict mode")

        if name == "__class__":
            return UserDefinedClassVariable(self.python_type())

        handler = getattr(self, f"method_attr_{name}", None)
        result = handler(tx) if handler is not None else None

        # Add a guard for type matching, these guards are checked before tensor guards
        # In some cases, a <tensor>.<attr> guard can be evaluated first, and break if
        # <tensor> is later changed to another type
        if (
            result is not None
            and self.source
            and self.source.subguards_allowed()
            and not (
                name not in ("grad", "requires_grad") and result.is_python_constant()
            )
        ):
            install_guard(self.make_guard(GuardBuilder.TYPE_MATCH))
            result.source = AttrSource(self.source, name)

        # It's hard to get inplace view (metadata mutation) on graph input work properly across
        # dynamo/aot/inductor, just fall back.
        if self.source is not None and hasattr(torch.ops.aten, name):
            fn = getattr(torch.ops.aten, name)
            if (
                hasattr(fn, "overloads")
                and hasattr(fn, fn.overloads()[0])
                and torch.Tag.inplace_view in getattr(fn, fn.overloads()[0]).tags
            ):
                # Delay the graph break to the actual call of unsqueeze_/resize_/resize_as_ etc.
                return variables.misc.DelayGraphBreakVariable(
                    source=AttrSource(self.source, name)
                )

        # For attributes (not methods) that were not caught in the special handling above,
        # (e.g. tensor.real), we handle these generically, assuming that the output type is
        # a tensor.
        if result is None and name != "grad":

            def try_generic_attr_handling():
                from .builder import wrap_fx_proxy
                from .misc import GetAttrVariable

                try:
                    static_attr = inspect.getattr_static(torch.Tensor, name)
                except AttributeError:
                    return None

                # Make sure this is an attribute, not a method.
                # type(torch.Tensor.H) should be "getset_descriptor"
                # This is a because of CPython implementation, see THPVariableType:
                # these attributes are implemented under tp_getset, which appear
                # as `getset_descriptor`s, (compared to, say, methods which appear
                # as `method_descriptor`s)
                if type(static_attr) != types.GetSetDescriptorType:
                    return None

                proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
                if self.source is not None:
                    return wrap_fx_proxy(
                        tx=tx, proxy=proxy, source=AttrSource(self.source, name)
                    )
                else:
                    return wrap_fx_proxy(tx=tx, proxy=proxy)

            result = try_generic_attr_handling()

        if result is None:
            result = self.dynamic_getattr(tx, name)

        if result is None:
            raise NotImplementedError
        return result

    def has_unpack_var_sequence(self, tx):
        return self.ndim > 0

    def unpack_var_sequence(self, tx, idxes=None):
        from .builder import wrap_fx_proxy_cls

        if self.size:
            size_len = len(self.size)
        else:
            size_var = self.call_method(tx, "size", [], {})
            assert isinstance(size_var, SizeVariable)
            size_len = len(size_var.items)
        # Ensure we don't unpack a scalar tensor.
        assert size_len != 0, "Can't unpack scalar tensors."

        if self.size:
            length = self.size[0]
        else:
            dyn_length = self.call_method(tx, "size", [ConstantVariable.create(0)], {})
            # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through
            # symbolic_shapes, but that end up as int/sympy.Integer
            assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable))
            if isinstance(dyn_length, SymNodeVariable):
                length = dyn_length.evaluate_expr(tx.output)
            else:
                length = dyn_length.value

        if idxes is None:
            idxes = range(length)
        else:
            assert (
                len(idxes) == length
            ), f"Can't unpack a tensor of {length} rows into a tuple of {len(idxes)} elements."
        return [
            wrap_fx_proxy_cls(target_cls=type(self), tx=tx, proxy=self.as_proxy()[i])
            for i in idxes
        ]

    def _strict_mode_banned_ops(self):
        return torch._dynamo.config._autograd_backward_strict_mode_banned_ops

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
            unimplemented(f"Illegal method invocation {name} in strict mode")

        """
        Dispatch to a method-specific handler defined below.  If the
        handler returns None (or doesn't exist) we put the method call
        in the graph.
        """
        try:
            handler_method = getattr(self, f"method_{name}")
        except AttributeError:
            pass
        else:
            try:
                result = handler_method(*args, **kwargs)
                if result:
                    return result
            except TypeError as e:
                unimplemented(f"unhandled args for {name}: {e}")

        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                name,
                *proxy_args_kwargs([self, *args], kwargs),
            ),
        )

    def method_size(self, *args, **kwargs):
        return self._method_size_stride("size", *args, **kwargs)

    def method_stride(self, *args, **kwargs):
        return self._method_size_stride("stride", *args, **kwargs)

    def _method_size_stride(self, name, dim=None):
        dim = guard_if_dyn(dim)

        def make_const_size_variable(x, **options):
            return SizeVariable(
                [ConstantVariable.create(y, **options) for y in x], **options
            )

        RetVariable = (
            make_const_size_variable if name == "size" else ConstantVariable.create
        )

        # Technically, this should not be necessary, but I'm including it
        # for enhanced BC, in case example_value is sometimes not set
        # (it really should always be set though!)
        if (r := getattr(self, name)) is not None:
            if dim is None:
                return RetVariable(r)
            else:
                return ConstantVariable.create(r[dim])

        # It might still be constant!  Consult the fake tensor and see
        if (fake := self.proxy.node.meta.get("example_value")) is not None:
            if dim is None:
                fake_r = getattr(fake, name)()
                if not has_free_symbols(fake_r):
                    # int conversion for safety, in case a SymInt refined
                    # to constant
                    return RetVariable(tuple(int(r) for r in fake_r))
            else:
                fake_r = getattr(fake, name)(dim)
                if not has_free_symbols(fake_r):
                    return ConstantVariable.create(int(fake_r))

    def method_numel(self):
        if self.size is not None:
            return ConstantVariable.create(product(self.size))

        # It might still be constant!  Consult the fake tensor and see
        if (fake := self.proxy.node.meta.get("example_value")) is not None:
            fake_r = fake.numel()
            if not has_free_symbols(fake_r):
                return ConstantVariable.create(int(fake_r))

    method_nelement = method_numel

    def method_dim(self):
        if self.ndim is not None:
            return ConstantVariable.create(self.ndim)

    method_ndimension = method_dim

    def method_is_floating_point(self):
        if self.dtype is not None:
            return ConstantVariable.create(self.dtype.is_floating_point)

    def method_is_complex(self):
        if self.dtype is not None:
            return ConstantVariable.create(self.dtype.is_complex)

    def method_is_contiguous(self, memory_format=None):
        memory_format = (
            memory_format.as_python_constant()
            if memory_format is not None
            else torch.contiguous_format
        )
        if self.is_contiguous is not None:
            return ConstantVariable.create(memory_format in self.is_contiguous)
        elif (fake := self.proxy.node.meta.get("example_value")) is not None:
            return ConstantVariable.create(
                fake.is_contiguous(memory_format=memory_format)
            )

    def method_type(self, dtype=None, non_blocking=False, **kwargs):
        if (
            dtype is None
            and self.dtype is not None
            and isinstance(self.device, torch.device)
        ):
            tensortype = next(
                k for k, v in tensortype_to_dtype.items() if self.dtype in v
            )
            if self.device.type == "cuda":
                return ConstantVariable.create(f"torch.cuda.{tensortype.__name__}")
            else:
                return ConstantVariable.create(f"torch.{tensortype.__name__}")
        elif (
            dtype is not None
            and fqn(type(dtype.as_python_constant())) == "torch.tensortype"
        ):
            # torch.FloatTensor, etc. are all of type "torch.tensortype".
            # torch.fx's tracer fails on these types, because it doesn't support arguments of torch.tensortype type.
            # So, we pass it in as a string (which is also supported, see above implementation for .type() with 0 args)
            tensor_type = dtype.as_python_constant()
            tensor_type_const = ConstantVariable.create(fqn(tensor_type))

            from ..symbolic_convert import InstructionTranslator
            from .builder import wrap_fx_proxy

            tx = InstructionTranslator.current_tx()

            if non_blocking:
                kwargs = {"non_blocking": non_blocking, **kwargs}

            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    "type",
                    *proxy_args_kwargs([self, tensor_type_const], kwargs),
                ),
            )

    def method_as_subclass(self, cls):
        if isinstance(cls, TensorSubclassVariable) and cls.source:
            from ..symbolic_convert import InstructionTranslator
            from .builder import VariableBuilder
            from .torch_function import TensorWithTFOverrideVariable

            tx = InstructionTranslator.current_tx()

            # [Note: __torch_function__] coerce this tensor variable into a TensorWithTFOverrideVariable
            # in eager, this is just a type change. This isn't sound if a __torch_function__ tensor subclass
            # defines a constructor, but if only a __torch_function__ impl is defined, this is okay to call.
            # It is up to the user whether this is correct behavior or not.
            py_cls = cls.as_python_constant()
            torch_fn = VariableBuilder(
                tx,
                AttrSource(AttrSource(cls.source, "__torch_function__"), "__func__"),
            )(py_cls.__torch_function__.__func__)

            return TensorWithTFOverrideVariable.from_tensor_var(
                tx, self, py_cls, torch_fn
            )

    def method_get_device(self):
        if isinstance(self.device, torch.device):
            index = self.device.index if self.device.type != "cpu" else -1
            return ConstantVariable.create(index)

    def method_element_size(self):
        return ConstantVariable.create(self.dtype.itemsize)

    def method_numpy(self, *, force=False):
        if not config.trace_numpy:
            unimplemented("Tensor.numpy(). config.trace_numpy is False")
        if not np:
            unimplemented("Tensor.numpy(). NumPy is not available")
        if self.layout != torch.strided:
            raise TypeError(
                f"can't convert {self.layout} layout tensor to numpy. Use Tensor.dense() first"
            )
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()

        # We don't check that the tensor is on CPU when force is False, as this
        # allows us to execute NumPy code on CUDA. Same for requires_grad=True
        if force and force.as_python_constant():
            # If the user set force=True we try to preserve the semantics (no gradients, move to CPU...)
            t = self.call_method(tx, "detach", [], {})
            proxy = tx.output.create_proxy("call_method", "cpu", (t.as_proxy(),), {})
        else:
            # Hacky way to create a view of self that will be marked as NumpyNdarrayVariable
            proxy = tx.output.create_proxy(
                "call_method", "view_as", *proxy_args_kwargs([self, self], {})
            )
        return NumpyNdarrayVariable.create(tx, proxy)

    def method_tolist(self):
        from ..symbolic_convert import InstructionTranslator
        from .builder import SourcelessBuilder

        tx = InstructionTranslator.current_tx()

        def tolist(tensor, sub_proxy):
            def wrap(i, sub_proxy):
                # Sigh, we forgot to gate this, so this data dependent is on
                # by default and is load bearing in CI
                with unittest.mock.patch.object(
                    tx.fake_mode, "allow_scalar_outputs", True
                ):
                    return SymNodeVariable.create(
                        tx,
                        sub_proxy.item(),
                    )

            if tensor.dtype not in [
                torch.int8,
                torch.int16,
                torch.int32,
                torch.int64,
            ]:
                unimplemented("Input tensor for tolist must be an integer tensor")

            if tensor.dim() == 0:
                return wrap(tensor, sub_proxy)

            if tensor.dim() == 1:
                return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)]

            return [
                tolist(sub_tensor, sub_proxy=sub_proxy[i])
                for i, sub_tensor in enumerate(tensor)
            ]

        tensor = self.as_proxy().node.meta["example_value"]
        out = tolist(tensor, self.as_proxy())
        return SourcelessBuilder.create(tx, out)

    def method_backward(self, *args, **kwargs):
        unimplemented("Tensor.backward")

    def method_data_ptr(self, *args, **kwargs):
        unimplemented("Tensor.data_ptr")

    def method_item(self, *args, **kwargs):
        if not config.capture_scalar_outputs:
            self._warn_capture_scalar_outputs()
            unimplemented("Tensor.item")

    @staticmethod
    @functools.lru_cache(None)
    def _warn_capture_scalar_outputs():
        log.warning(
            textwrap.dedent(
                """\
                    Graph break from `Tensor.item()`, consider setting:
                        torch._dynamo.config.capture_scalar_outputs = True
                    or:
                        env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
                    to include these operations in the captured graph.
                """
            )
        )

    def method___len__(self):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        return self.call_method(tx, "size", [ConstantVariable.create(0)], {})

    def method___setitem__(self, key, value):
        def has_bool_key(v):
            if isinstance(v, TensorVariable):
                return v.dtype in (torch.bool, torch.int8)
            elif isinstance(v, variables.TupleVariable):
                return any(has_bool_key(item) for item in v.items)
            else:
                return False

        if (
            has_bool_key(key)
            and isinstance(value, TensorVariable)
            and value.requires_grad
            and torch.is_grad_enabled()
        ):
            unimplemented(
                "boolean masking setitem backwards, see https://github.com/pytorch/pytorch/issues/114123"
            )
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        tx.output.create_proxy(
            "call_function",
            operator.setitem,
            *proxy_args_kwargs([self, key, value], {}),
        )
        return ConstantVariable.create(None)

    def method_resize_(self, *args, **kwargs):
        unimplemented("Tensor.resize_")

    def method_resize_as_(self, *args, **kwargs):
        unimplemented("Tensor.resize_as_")

    def method_sparse_resize_(self, *args, **kwargs):
        unimplemented("Tensor.sparse_resize_")

    def method_sparse_resize_and_clear_(self, *args, **kwargs):
        unimplemented("Tensor.sparse_resize_and_clear_")

    def method_set_(self, *args, **kwargs):
        if len(args) > 1:
            # torch.Tensor.set_() has several overloads.
            # aten::set_.source_Tensor(Tensor) gets special handling
            # in AOTAutograd and functionalization, because it is the most common
            # overload and is used by FSDP.
            # graph-breaking on aten::set_source_Tensor_storage_offset for now,
            # unless we find that we need to make it work.
            unimplemented("Tensor.set_.source_Tensor_storage_offset")

    def method_add_(self, other, *, alpha=None):
        if alpha is not None:
            from ..symbolic_convert import InstructionTranslator

            tx = InstructionTranslator.current_tx()
            result = variables.TorchInGraphFunctionVariable(torch.mul).call_function(
                tx, [other, alpha], {}
            )
            return self.call_method(tx, "add_", [result], {})

    def method_addcdiv_(self, tensor1, tensor2, *, value=None):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        if value is not None:
            result = variables.TorchInGraphFunctionVariable(torch.div).call_function(
                tx, [tensor1, tensor2], {}
            )
            result = variables.TorchInGraphFunctionVariable(torch.mul).call_function(
                tx, [result, value], {}
            )
            return self.call_method(tx, "add_", [result], {})

    def method___contains__(self, arg):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()

        # Rewrite __contains__ here so that downstream passes can trace through
        # without dealing with unbacked symbool. Roughly the code we translate is:
        # def __contains__(self, x):
        #     return (x == self).any().item()
        result = variables.TorchInGraphFunctionVariable(torch.eq).call_function(
            tx, [self, arg], {}
        )
        result = variables.TorchInGraphFunctionVariable(torch.any).call_function(
            tx, [result], {}
        )
        return result.call_method(tx, "item", [], {})

    def method_redistribute(self, *args, **kwargs):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
        # and rewrite args to have only proxyable args, then insert call_function
        args_as_value = [x.as_python_constant() for x in args]
        kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}

        def redistribute_fn_with_prim_types(x):
            return x.redistribute(*args_as_value, **kwargs_as_value)

        # attach the same function name for better debugging
        redistribute_fn_with_prim_types.__name__ = "prim_redistribute"

        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                redistribute_fn_with_prim_types,
                *proxy_args_kwargs([self], {}),
            ),
        )

    def method_to_local(self, *args, **kwargs):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
        # and rewrite args to have only proxyable args, then insert call_function
        args_as_value = [x.as_python_constant() for x in args]
        kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}

        def to_local_fn_with_prim_types(x):
            return x.to_local(*args_as_value, **kwargs_as_value)

        # attach the same function name for better debugging
        to_local_fn_with_prim_types.__name__ = "prim_to_local"

        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                to_local_fn_with_prim_types,
                *proxy_args_kwargs([self], {}),
            ),
        )

    def method_register_hook(self, *args, **kwargs):
        return self._method_register_hook("register_hook", *args, **kwargs)

    def method_register_post_accumulate_grad_hook(self, *args, **kwargs):
        return self._method_register_hook(
            "register_post_accumulate_grad_hook", *args, **kwargs
        )

    def _method_register_hook(self, name: str, hook: VariableTracker):
        # Note - do not arbitrarily add hooks here - make sure they match the same contract
        # see [On tensor.register_hook]
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()

        if not self.source:
            if not compiled_autograd.compiled_autograd_enabled:
                # TODO(voz):
                # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary
                # python state.
                # We *Must* be in compiled_autograd here because backward hooks can contain anything, and it is unsafe to run
                # them in a compiled bwd without re-entering dynamo as compiled_autograd does.
                #
                # Discussion point 1 - Should we bypass this if nopython/fullgraph = True?
                #   No. Because this was going to be a graph break anyway - this check does not
                # introduce new graph breaks where there were none.
                #
                # Discussion point 2 - Should we defer this check to backwards?
                #   No. Because compiled autograd is not yet ready for prime time. As such, if we defer, a user
                # would have no recourse - their forward traces just fine, but will fail at backwards unless
                # compiled_autograd is enabled. If compiled_autograd fails (there are a lot of failures today)
                # then they have nothing they can do except disable compile.
                unimplemented(
                    "Compilation of intermediate hooks requires compiled autograd"
                )

            hook_name, bw_state_proxy = tx.output.add_backward_state_hook(hook)

            def _register_hook_trampoline(tensor, bw_state):
                register_hook = getattr(tensor, name)
                register_hook(
                    functools.partial(
                        trace_wrapped,
                        fn=call_hook_from_backward_state,
                        bw_state=bw_state,
                        hook_name=hook_name,
                    )
                )
                # TODO(jansel): returning None here is wrong, it should be
                # RemovableHandle, but we need some extra work to support
                # this properly.
                return None

            from .builder import wrap_fx_proxy

            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_function",
                    _register_hook_trampoline,
                    (self.as_proxy(), bw_state_proxy),
                    {},
                ),
            )

        handle_variable = variables.RemovableHandleVariable(
            mutable_local=variables.base.MutableLocal(),
        )
        tx.output.side_effects.register_hook(self, hook, handle_variable, name)
        return handle_variable

    def method_requires_grad_(self, requires_grad=True):
        if requires_grad is not True:
            requires_grad = requires_grad.as_python_constant()

        if self.as_proxy().node.meta["example_value"].requires_grad != requires_grad:
            unimplemented("Tensor.requires_grad_")
        else:
            return self

    def method_new(self, *args, **kwargs):
        # Convert x.new(torch.Size) into x.new_empty(torch.Size),
        # as Tensor.new acts differently with a Size input versus a tuple input.
        if (len(args) == 1 and isinstance(args[0], SizeVariable)) or (
            len(args) >= 1
            and all(
                isinstance(a, ConstantVariable) and a.python_type() == int for a in args
            )
        ):
            from ..symbolic_convert import InstructionTranslator

            return self.call_method(
                InstructionTranslator.current_tx(), "new_empty", args, kwargs
            )

    def method_untyped_storage(self):
        return UntypedStorageVariable(
            self, self.as_proxy().node.meta["example_value"].untyped_storage()
        )

    def set_name_hint(self, name: str):
        if not self._is_name_set:
            self.proxy.node._rename(name)
            self._is_name_set = True


class SymNodeVariable(VariableTracker):
    """
    Represents a symbolic scalar, either int, float or bool.  This is most commonly used to
    handle symbolic size computation, e.g., tensor.size(0), but it is also used to
    handle logic like float_tensor.item() or unspecialized float inputs.
    """

    _nonvar_fields = {
        "proxy",
        "sym_num",
        *VariableTracker._nonvar_fields,
    }

    def debug_repr(self):
        return repr(self.sym_num)

    @classmethod
    def create(cls, tx, proxy, sym_num=None, **options):
        if sym_num is None:
            sym_num = get_fake_value(proxy.node, tx)
        if "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == sym_num
        set_example_value(proxy.node, sym_num)

        if isinstance(sym_num, (sympy.Integer, int, bool)):
            sym_num = int(sym_num) if isinstance(sym_num, sympy.Integer) else sym_num
            return ConstantVariable.create(sym_num)

        return SymNodeVariable(proxy, sym_num, **options)

    def __init__(self, proxy, sym_num, **kwargs):
        super().__init__(**kwargs)
        self.proxy = proxy
        # TODO: Should we allow non SymTypes here?  Today it is allowed
        self.sym_num = sym_num
        self._tensor_var = None

    def python_type(self):
        if isinstance(self.sym_num, SymTypes):
            return self.sym_num.node.pytype
        else:
            return type(self.sym_num)

    def as_proxy(self):
        return self.proxy

    def as_tensor(self, tx):
        if self._tensor_var is None:
            from .builder import SourcelessBuilder

            self._tensor_var = SourcelessBuilder.create(
                tx, torch.scalar_tensor
            ).call_function(tx, [self], {})
        return self._tensor_var

    def evaluate_expr(self, output_graph=None):
        try:
            return guard_scalar(self.sym_num)
        except GuardOnDataDependentSymNode as e:
            raise UserError(  # noqa: B904
                UserErrorType.ANTI_PATTERN,
                f"Consider annotating your code using torch._check*(). {str(e)}",
                case_name="constrain_as_size_example",
            )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                name,
                *proxy_args_kwargs([self, *args], kwargs),
            ),
        )


class NumpyNdarrayVariable(TensorVariable):
    """
    Represents a np.ndarray, but backed by torch Tensor via torch._numpy.ndarray.
    Use this for Tensor.numpy() call.
    """

    @staticmethod
    def create(tx, proxy, **options):
        from .builder import wrap_fx_proxy_cls

        return wrap_fx_proxy_cls(
            target_cls=NumpyNdarrayVariable,
            tx=tx,
            proxy=proxy,
            **options,
        )

    def var_getattr(self, tx, name):
        # NB: This INTENTIONALLY does not call super(), because there is
        # no intrinsic reason ndarray properties are related to Tensor
        # properties.  The inheritance here is for implementation sharing.

        from ..utils import numpy_attr_wrapper
        from .builder import wrap_fx_proxy

        result = None

        example_value = self.as_proxy().node.meta["example_value"]
        example_ndarray = tnp.ndarray(example_value)

        def insert_into_graph():
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_function", numpy_attr_wrapper, (self.as_proxy(), name), {}
                ),
            )

        if name in ["T", "real", "imag"]:
            proxy = tx.output.create_proxy(
                "call_function",
                numpy_attr_wrapper,
                (self.as_proxy(), name),
                {},
            )
            result = NumpyNdarrayVariable.create(tx, proxy)

        # These are awkward to implement.  The standard playbook for torch._numpy
        # interop is to trace a call into the torch._numpy wrapper which works for
        # Tensor operations.  However, we don't want to do this for calls
        # that don't return Tensors, because in those cases we may not want
        # to trace the attribute access into the graph at all (it is sort
        # of harmless to do so, because AOTAutograd will eliminate them,
        # but it's best not to trace them in to begin with.)  But in any
        # case, tracing these into the graph is like trying to fit a square
        # peg into a round hole; best not to do it.  So instead we
        # painstakingly implement these by hand
        #
        # NB: only ALWAYS specialized attributes can go here; notably,
        # size/shape not allowed!
        elif name in ("ndim", "itemsize"):
            return ConstantVariable.create(getattr(example_ndarray, name))
        elif name in ("shape", "stride"):
            if not has_free_symbols(r := getattr(example_ndarray, name)):
                return ConstantVariable.create(tuple(int(r) for r in r))
            return insert_into_graph()
        elif name == "size":
            if not has_free_symbols(r := example_ndarray.size):
                return ConstantVariable.create(int(r))
            return insert_into_graph()
        elif name in ["base", "flags", "dtype"]:
            unimplemented(f"TODO: add support for ndarray.{name}")
        elif name in ["__version__"]:
            unimplemented("delegate np.__version__ to NumPy")
        if result is None:
            raise NotImplementedError
        return result

    @staticmethod
    def patch_args(name, args, kwargs):
        if name == "clip":
            kwargs_rename = {"a_min": "min", "a_max": "max"}
            kwargs = {kwargs_rename.get(k, k): v for k, v in kwargs.items()}
        return args, kwargs

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from ..utils import numpy_method_wrapper

        args, kwargs = self.patch_args(name, args, kwargs)

        if name in ["__len__", "size", "tolist"]:
            # delegate back to TensorVariable
            return super().call_method(tx, name, args, kwargs)
        if name == "tobytes":
            unimplemented("tobytes is not modelled in torch._numpy")
        proxy = tx.output.create_proxy(
            "call_function",
            numpy_method_wrapper(name),
            *proxy_args_kwargs([self] + list(args), kwargs),
        )
        return NumpyNdarrayVariable.create(tx, proxy)

    def python_type(self):
        return np.ndarray


class UnspecializedPythonVariable(TensorVariable):
    """
    This is a 1-element tensor represents unspecialized python float/int.
    """

    _nonvar_fields = {
        "raw_value",
        "need_unwrap",
        *TensorVariable._nonvar_fields,
    }

    def __init__(
        self, proxy: torch.fx.Proxy, *, raw_value=None, need_unwrap=True, **kwargs
    ):
        super().__init__(proxy, **kwargs)
        self.raw_value = raw_value
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True):
        # Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance.
        return UnspecializedPythonVariable(
            **dict(tensor_variable.__dict__),
            raw_value=raw_value,
            need_unwrap=need_unwrap,
        )


class FakeItemVariable(TensorVariable):
    """An unspecialized python variable which prevents access to the underlying raw value.
    This is needed if item is called on a FakeTensor."""

    _nonvar_fields = {
        "need_unwrap",
        *TensorVariable._nonvar_fields,
    }

    def __init__(self, proxy: torch.fx.Proxy, **kwargs):
        need_unwrap = kwargs.pop("need_unwrap", False)
        super().__init__(proxy, **kwargs)
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable):
        return FakeItemVariable(**dict(tensor_variable.__dict__))


class TensorSubclassVariable(VariableTracker):
    def __init__(self, value, *args, **kwargs):
        self.value = value
        super().__init__(*args, **kwargs)

    def call_function(
        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
    ) -> VariableTracker:
        if len(args) == 1 and isinstance(args[0], TensorVariable):
            from .builder import VariableBuilder
            from .torch_function import TensorWithTFOverrideVariable

            torch_fn = VariableBuilder(
                tx, AttrSource(self.source, "__torch_function__")
            )(self.value.__torch_function__)

            return TensorWithTFOverrideVariable.from_tensor_var(
                tx, args[0], self.value, torch_fn
            )

        return super().call_function(tx, args, kwargs)

    def as_python_constant(self):
        return self.value

    def python_type(self):
        return type(self.value)


class UntypedStorageVariable(VariableTracker):
    _nonvar_fields = {
        "example_value",
        *VariableTracker._nonvar_fields,
    }

    def __init__(
        self,
        from_tensor: TensorVariable,
        example_value: torch.UntypedStorage,
        **kwargs,
    ):
        super().__init__(**kwargs),
        self.from_tensor = from_tensor
        # Example_value will always have device="meta"
        self.example_value = example_value

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        if name == "size":
            assert not args
            assert not kwargs
            result = self.example_value.size()
            if not has_free_symbols(result):
                # avoid creating a node in the graph
                return ConstantVariable.create(int(result))
            else:
                from ..external_utils import untyped_storage_size
                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(
                    tx,
                    tx.output.create_proxy(
                        "call_function",
                        untyped_storage_size,
                        (self.from_tensor.as_proxy(),),
                        {},
                    ),
                )
        if name == "resize_" and len(args) == 1:
            assert not kwargs
            tx.output.create_proxy(
                "call_function",
                torch.ops.inductor.resize_storage_bytes_,
                (self.from_tensor.as_proxy(), args[0].as_proxy()),
                {},
            )
            return self

        return super().call_method(tx, name, args, kwargs)

    def reconstruct(self, codegen):
        codegen(self.from_tensor)
        codegen.load_method("untyped_storage")
        codegen.call_method(0)
